
Improve dot lift rewrites #1471


Merged: 12 commits into pymc-devs:main from dot_lift_rewrite on Jul 23, 2025
Conversation

ricardoV94 (Member) commented Jun 13, 2025

This PR was motivated by the partial Jacobian computation example in JAX discussed in jax-ml/jax#5904 (comment).

After #1228 it's actually easier to do this sort of optimization in PyTensor, since there's no scan to worry about. We already had a bunch of rewrites to lift subtensor operations through elemwise and dots, but we had no rewrites to lift them through Blockwise (and the Blockwise dot, a.k.a. matmul). This PR addresses that.
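
For intuition, these lifts rest on simple algebraic identities: indexing the result of a dot (or of a batched/Blockwise matmul) is equivalent to indexing its inputs first, so a sliced result can skip most of the underlying computation. A minimal NumPy sketch of those identities (illustration only, not code from this PR):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(100, 20))
B = rng.normal(size=(20, 30))

# Slicing rows of a dot product equals dotting only the sliced rows,
# so the subtensor can be lifted above the dot and most rows are never computed.
np.testing.assert_allclose((A @ B)[:5], A[:5] @ B)

# The same holds for indexing a batch dimension of a batched (Blockwise) matmul:
# only the selected batch entry needs to be computed.
X = rng.normal(size=(8, 4, 6))
Y = rng.normal(size=(8, 6, 3))
np.testing.assert_allclose((X @ Y)[2], X[2] @ Y[2])
```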

Some notes on the major changes

  1. Do constant_folding in Python mode. This is not related to this PR, but I noticed a test was taking 10x longer than the others just because a simple constant-folding operation was being triggered in the rewrites, which loaded the whole C cache. That incurs a one-time penalty that's pretty large, and for users not interested in the C backend at all there's no reason to involve that machinery; a single Python eval should be fast enough anyway. This was moved to Fix CheckAndRaise Op C implementation #1521, as it revealed an unrelated bug.

  2. Simplified local_upcast_elemwise. This rewrite was too complex and wasteful, in that it wrapped constants in symbolic expand_dims / alloc + cast; I now do it directly in NumPy. This reduces the number of rewrite iterations.

  3. A bunch of improvements to rewrites, including lifting index operations on the batch dimensions of Blockwise and expanding the dot subtensor lift to handle the Blockwise case (that rewrite predates Blockwise). The others are self-explanatory.

  4. Canonicalize matvec, vecmat, and vecdot internally to all use matmul (i.e., a Blockwise of the 2D matrix-matrix dot). This makes things simpler for our rewrites, because we only need to worry about one case (see the NumPy sketch after this list).

  5. The pre-existing local_batched_matmul_to_core_matmul rewrite was extended to better address cases of batched matvec, vecmat, and vecdot (batch dimensions are moved to the core dimensions). It now moves non-overlapping batch dimensions of both inputs to their core dimensions, and it tries to avoid reshapes (needed when combining multiple batch/core dimensions) so that the subtensor_lift rewrites mentioned above can still work through them.

  6. Prioritize gemv/ger, which also makes several xfail tests pass. The cause of those xfails was probably misattributed.
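
A minimal NumPy sketch of the equivalences behind item 4: each vector variant is a matmul on inputs promoted to matrices, with the extra length-1 dimensions dropped afterwards (illustration only; the actual canonicalization targets PyTensor's internal _matmul, not this code):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
B = rng.normal(size=(4, 5))
v = rng.normal(size=(4,))
w = rng.normal(size=(4,))

# matvec(A, v): promote the vector to a column matrix, then drop that axis
np.testing.assert_allclose(A @ v, (A @ v[:, None])[:, 0])

# vecmat(v, B): promote the vector to a row matrix, then drop that axis
np.testing.assert_allclose(v @ B, (v[None, :] @ B)[0, :])

# vecdot(v, w): both vectors promoted; the result is a 1x1 matrix
np.testing.assert_allclose(v @ w, (v[None, :] @ w[:, None])[0, 0])
```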

Benchmark results were added in the last commit (a rough sketch of the benchmark setup follows the tables below).
(Note that vectorize=True goes from underperforming (~28 ms) to outperforming (~0.37 ms).)

Before
------------------------------------------------------------------------------------------------- benchmark: 2 tests ------------------------------------------------------------------------------------------------
Name (time in ms)                                        Min                Max               Mean            StdDev             Median               IQR            Outliers       OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_partial_jacobian[vectorize=False]      1.9453 (1.0)       2.8201 (1.0)       2.2296 (1.0)      0.0963 (1.0)       2.2031 (1.0)      0.0855 (1.0)         52;25  448.5095 (1.0)         421           1
test_benchmark_partial_jacobian[vectorize=True]      28.8122 (14.81)    36.9261 (13.09)    34.1470 (15.32)    2.3973 (24.90)    34.8889 (15.84)    2.6797 (31.35)         8;1   29.2851 (0.07)         21           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After
--------------------------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------------------------
Name (time in us)                                           Min                   Max                  Mean             StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_partial_jacobian[vectorize=True]        345.7980 (1.0)        658.8850 (1.0)        370.9925 (1.0)      41.1362 (1.0)        357.2400 (1.0)      16.9117 (1.0)         24;34  2,695.4724 (1.0)         287           1
test_benchmark_partial_jacobian[vectorize=False]     2,148.9270 (6.21)     3,062.8910 (4.65)     2,215.2234 (5.97)     77.6787 (1.89)     2,194.7940 (6.14)     44.7890 (2.65)        33;34    451.4217 (0.17)        496           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
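
For reference, a rough sketch of how such a partial-Jacobian benchmark could look with pytest-benchmark. This is not the actual test from tests/test_gradient.py: the function being differentiated is made up here, and the vectorize flag of pytensor.gradient.jacobian is assumed from the discussion of #1228 above.

```python
import numpy as np
import pytest

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian


@pytest.mark.parametrize("vectorize", [False, True], ids=lambda v: f"vectorize={v}")
def test_benchmark_partial_jacobian(vectorize, benchmark):
    x = pt.vector("x")
    A = pt.constant(np.random.default_rng(0).normal(size=(1000, 1000)))
    y = pt.tanh(A @ x)

    # Only a 5x5 corner of the full 1000x1000 Jacobian is requested;
    # the subtensor-lift rewrites let PyTensor compute just that corner.
    J = jacobian(y, x, vectorize=vectorize)[:5, :5]
    f = pytensor.function([x], J)

    x_test = np.random.default_rng(1).normal(size=(1000,))
    benchmark(f, x_test)
```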

Vectorized Jacobian graph before:

Subtensor{:stop, :stop} [id A] shape=(5, 5) 9
 ├─ DimShuffle{order=[1,0]} [id B] shape=(1000, 1000) 8
 │  └─ Reshape{3} [id C] shape=(1000, 1000, 1) 7
 │     ├─ Dot22 [id D] shape=(1000, 1000) 6
 │     │  ├─ [[0.903246 ... 74841955]] [id E] shape=(1000, 1000)
 │     │  └─ Reshape{2} [id F] shape=(1000, 1000) 5
 │     │     ├─ True_div [id G] shape=(1000, 1000, 1) 4
 │     │     │  ├─ [[[0.0005] ... [0.0005]]] [id H] shape=(1000, 1000, 1)
 │     │     │  └─ Composite{sqrt((0.001 * i0))} [id I] shape=(1000, 1, 1) 3
 │     │     │     └─ ExpandDims{axes=[1, 2]} [id J] shape=(1000, 1, 1) 2
 │     │     │        └─ CGemv{inplace} [id K] shape=(1000,) 1
 │     │     │           ├─ AllocEmpty{dtype='float64'} [id L] shape=(1000,) 0
 │     │     │           │  └─ 1000 [id M] shape=()
 │     │     │           ├─ 1.0 [id N] shape=()
 │     │     │           ├─ [[0.903246 ... 74841955]] [id O] shape=(1000, 1000)
 │     │     │           ├─ x [id P] shape=(?,)
 │     │     │           └─ 0.0 [id Q] shape=()
 │     │     └─ [1000   -1] [id R] shape=(2,)
 │     └─ [1000 1000    1] [id S] shape=(3,)
 ├─ 5 [id T] shape=()
 └─ 5 [id T] shape=()

and after:

Dot22 [id A] shape=(5, 5) 5
 ├─ True_div [id B] shape=(5, 1000) 4
 │  ├─ [[0.0005 0 ... 0.    ]] [id C] shape=(5, 1000)
 │  └─ Composite{sqrt((0.001 * i0))} [id D] shape=(1, 1000) 3
 │     └─ ExpandDims{axis=0} [id E] shape=(1, 1000) 2
 │        └─ CGemv{inplace} [id F] shape=(1000,) 1
 │           ├─ AllocEmpty{dtype='float64'} [id G] shape=(1000,) 0
 │           │  └─ 1000 [id H] shape=()
 │           ├─ 1.0 [id I] shape=()
 │           ├─ [[0.903246 ... 74841955]] [id J] shape=(1000, 1000)
 │           ├─ x [id K] shape=(?,)
 │           └─ 0.0 [id L] shape=()
 └─ [[0.903246 ... 45926986]] [id M] shape=(1000, 5)

📚 Documentation preview 📚: https://pytensor--1471.org.readthedocs.build/en/1471/


codecov bot commented Jul 9, 2025

Codecov Report

Attention: Patch coverage is 92.62673% with 16 lines in your changes missing coverage. Please review.

Project coverage is 81.49%. Comparing base (4d539fa) to head (e17a627).
Report is 12 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/tensor/rewriting/subtensor_lift.py | 82.53% | 8 Missing and 3 partials ⚠️ |
| pytensor/tensor/math.py | 85.71% | 1 Missing and 1 partial ⚠️ |
| pytensor/tensor/rewriting/blas.py | 50.00% | 1 Missing and 1 partial ⚠️ |
| pytensor/tensor/rewriting/math.py | 98.68% | 0 Missing and 1 partial ⚠️ |
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1471      +/-   ##
==========================================
- Coverage   81.51%   81.49%   -0.02%     
==========================================
  Files         232      232              
  Lines       53033    53122      +89     
  Branches     9424     9444      +20     
==========================================
+ Hits        43229    43292      +63     
- Misses       7362     7382      +20     
- Partials     2442     2448       +6     
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/tensor/rewriting/elemwise.py | 93.37% <100.00%> (+0.65%) ⬆️ |
| pytensor/tensor/rewriting/linalg.py | 92.08% <100.00%> (ø) |
| pytensor/tensor/rewriting/subtensor.py | 90.27% <100.00%> (+0.18%) ⬆️ |
| pytensor/tensor/rewriting/math.py | 89.27% <98.68%> (-0.36%) ⬇️ |
| pytensor/tensor/math.py | 92.78% <85.71%> (-0.20%) ⬇️ |
| pytensor/tensor/rewriting/blas.py | 89.28% <50.00%> (-0.36%) ⬇️ |
| pytensor/tensor/rewriting/subtensor_lift.py | 91.05% <82.53%> (-1.24%) ⬇️ |

... and 2 files with indirect coverage changes



Copilot AI left a comment


Pull Request Overview

This PR extends and simplifies subtensor lifting and matmul-related rewrites to support Blockwise ops, unifies all matmul variants under _matmul, and adds tests and performance benchmarks for partial Jacobian computations.

  • Extend local_subtensor_of_dot and local_subtensor_of_elemwise to handle batched/blockwise cases and add a new squeeze-based subtensor lift.
  • Unify all matmul-like ops (matvec, vecmat, vecdot, and matrix–matrix) to use a single _matmul core and implement batch‐to‐core‐matmul rewrites with optional reshape.
  • Add new tests for blockwise subtensor lifts, batched matvec rewrites, and partial Jacobian benchmarks; adjust tolerances and seeds for existing tests.

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

Summary per file:

| File | Description |
|---|---|
| tests/test_gradient.py | Import sqrt and add test_benchmark_partial_jacobian |
| tests/tensor/test_math.py | Fix RNG seed and set atol for vector/matrix operation tests |
| tests/tensor/test_blas.py | Remove xfail markers, add skipif and rename parameters |
| tests/tensor/rewriting/test_subtensor_lift.py | Rename subtensor-of-elemwise tests, import Op, add blockwise tests |
| tests/tensor/rewriting/test_math.py | Add test_batch_matvec_to_matmul parameterized test |
| tests/tensor/rewriting/test_blas.py | Update imports, skip fast compile mode, adjust rewrite assertions |
| pytensor/tensor/rewriting/subtensor_lift.py | Enhance local_subtensor_of_dot and local_subtensor_of_batch_dims, add squeeze lift |
| pytensor/tensor/rewriting/subtensor.py | Minor cleanup in slice merging and useless-slice rewrites |
| pytensor/tensor/rewriting/math.py | Replace DimShuffle-through-dot rewrite with unified _matmul, reposition specializations |
| pytensor/tensor/rewriting/linalg.py | Update import of _matmul and use in transpose/blockwise rewrites |
| pytensor/tensor/rewriting/elemwise.py | Simplify upcast-constant rewrite, add register_stabilize |
| pytensor/tensor/rewriting/blas.py | Adjust rewrite positions and batched-dot reshaping logic |
| pytensor/tensor/math.py | Add dimension check to Dot22.make_node, unify matmul variants |
Comments suppressed due to low confidence (1)

tests/tensor/rewriting/test_subtensor_lift.py:194

  • The test references tensor3 but it is not imported; add from pytensor.tensor import tensor3 to the file's imports to avoid a NameError.
        x = tensor3("x", shape=(7, 5, 11), dtype="float64")

ricardoV94 force-pushed the dot_lift_rewrite branch 2 times, most recently from e83fe3a to e905280 on July 22, 2025 10:40
A new rewrite is added to convert unpaired batched row/column matvec or vec products into equivalent matmul products (see the sketch below).
The marked xfail test was failing because Ger wasn't introduced, not because of the complex dtype.
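
A NumPy illustration of the idea behind that new rewrite: an unpaired batch dimension on the matrix side of a matvec can be folded into the core rows, turning the batched product into a single 2D matmul (the real rewrite avoids the explicit reshape when it can, but the numerical equivalence is the same):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(8, 5, 3))   # batched matrices
v = rng.normal(size=(3,))        # unbatched vector

batched = A @ v                                      # (8, 5): one matvec per batch entry
as_core = (A.reshape(8 * 5, 3) @ v).reshape(8, 5)    # single 2D matmul
np.testing.assert_allclose(batched, as_core)
```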
jessegrabowski (Member) left a comment


We went over this PR extensively on a call, so this isn't just a snap approve.

ricardoV94 merged commit 12213d0 into pymc-devs:main on Jul 23, 2025
72 of 73 checks passed
ricardoV94 deleted the dot_lift_rewrite branch on July 23, 2025 09:41